package com.miya.demo.support.websocket;

import lombok.extern.slf4j.Slf4j;
import okhttp3.WebSocket;
import org.springframework.stereotype.Component;

import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * websocket服务
 *
 * @author CaiXiaowei
 * @date 2023/1/17
 */
@Slf4j
@ServerEndpoint(value = "/websocket/{userId}")
@Component
public class WebSocketServer {

    /**
     * 用户id
     */
    private String userId;

    /**
     * 会话: 与某个客户端的连接会话，需要通过它来给客户端发送数据
     */
    private Session session;

    /**
     * 在线数
     */
    private static int onlineCount;

    /**
     * 客户端对应的WebSocketServer对象
     */
    private static ConcurrentHashMap<String, WebSocketServer> webSocketMap = new ConcurrentHashMap<>();

    /**
     * 用来存在线连接数
     */
    private static final ConcurrentHashMap<String, Session> sessionPool = new ConcurrentHashMap<>();

    /**
     * 链接成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value = "userId") String userId) {
        try {
            this.session = session;
            webSocketMap.put(userId, this);
            sessionPool.put(userId, session);
            // 在线人数+1
            addOnlineCount();
            log.info("websocket消息: 连接数:{}, 用户数:{}", onlineCount, webSocketMap.size());
        } catch (Exception e) {

        }
    }

    /**
     * 收到客户端消息后调用的方法
     */
    @OnMessage
    public void onMessage(String message) {
        log.info("websocket消息: 收到客户端消息:" + message);
    }

    /**
     * 发送单个消息
     *
     * @param userId  用户id
     * @param message 消息
     */
    public void sendOneMessage(String userId, String message) {
        Session session = sessionPool.get(userId);
        if (session != null && session.isOpen()) {
            try {
                log.info("websocket消: 单点消息:" + message);
                session.getAsyncRemote().sendText(message);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    @OnError
    public void onError(Session session, Throwable error) {
        log.error("ws异常,userId:{}, error:{}", userId, error);
    }

    public synchronized int getOnlineCount() {
        return onlineCount;
    }

    public synchronized void addOnlineCount() {
        onlineCount++;
    }

    public synchronized void subOnlineCount() {
        onlineCount--;
    }
}
